#!/usr/bin/env python
import enum
import glob
import json
import os
import subprocess
import sys
import tempfile
import time
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path
from typing import Generator, List, Optional

import typer
import yaml

# Default parent directory under which each collection repository is checked
# out; the leading "~" is expanded to the user's home directory at import time.
DEFAULT_ROOT = Path("~/src/github.com/PrefectHQ").expanduser()

# The Typer application that the `@app.command()` functions in this file
# register themselves with.
app = typer.Typer(help="Utilities for managing Prefect collections en masse")


def prefect_collections() -> List[str]:
    """Return the names of the Prefect-authored collections in the catalog.

    Reads every ``*.yaml`` file in ``docs/integrations/catalog`` (relative to
    the current working directory), skips the ``template.yaml`` placeholder,
    and collects the ``collectionName`` of each entry whose ``author`` is
    ``"Prefect"``.

    Returns:
        The collection names, ordered by catalog file path so that repeated
        runs visit repositories in the same order.

    Raises:
        KeyError: If a Prefect-authored entry has no ``collectionName``.
    """
    names = []

    # glob makes no ordering guarantee, so sort for a stable iteration order
    for file in sorted(glob.glob("docs/integrations/catalog/*.yaml")):
        # Compare the base name instead of a "/"-prefixed suffix so the
        # template is skipped regardless of the platform's path separator
        if os.path.basename(file).lower() == "template.yaml":
            continue

        with open(file, encoding="utf-8") as f:
            entry = yaml.safe_load(f)

        # An empty file loads as None; anything that is not a mapping cannot
        # describe a collection, so skip it rather than crash on indexing
        if not isinstance(entry, dict):
            continue

        if entry.get("author") == "Prefect":
            names.append(entry["collectionName"])

    return names


def ensure_repo(root: Path, collection: str) -> Path:
    """Return the local checkout path for ``collection``, cloning it if needed.

    The checkout lives at ``root / collection``.  When that directory has no
    ``.git`` entry, the repository is cloned over SSH from the PrefectHQ
    GitHub organization; ``subprocess.CalledProcessError`` propagates if the
    clone fails.
    """
    checkout = root.joinpath(collection)
    if checkout.joinpath(".git").exists():
        return checkout

    remote = f"git@github.com:PrefectHQ/{collection}.git"
    subprocess.check_call(["git", "clone", remote, str(checkout)])
    return checkout


def repositories(root: Path) -> Generator[Path, None, None]:
    """Yield the local checkout of each Prefect collection beneath ``root``.

    Checkouts are produced lazily, one per name from ``prefect_collections()``,
    and each is cloned on demand by ``ensure_repo``.  Because this is a
    generator, the ``typer.BadParameter`` for a missing ``root`` is raised when
    iteration starts, not when the function is called.
    """
    if root.exists():
        for name in prefect_collections():
            yield ensure_repo(root, name)
        return

    raise typer.BadParameter(f"Root directory {root} does not exist")


def _shell(repo: Path, command: str):
    env = os.environ.copy()
    env["COLLECTION"] = repo.name
    subprocess.check_call(command, shell=True, cwd=str(repo))


def _in_env(repo: Path, command: str):
    """Run ``command`` via ``pyenv exec`` from within ``repo``.

    The child environment is the current one extended with ``COLLECTION`` and
    ``PYENV_VERSION`` (both the repository directory's name) and ``PYENV_DIR``
    (the repository path).  ``subprocess.CalledProcessError`` propagates when
    the command exits non-zero.
    """
    overrides = {
        "COLLECTION": repo.name,
        "PYENV_VERSION": repo.name,
        "PYENV_DIR": str(repo),
    }
    child_env = {**os.environ, **overrides}
    subprocess.check_call(
        f"pyenv exec {command}",
        shell=True,
        cwd=str(repo),
        env=child_env,
    )


@app.command()
def initialize_python_environment(python_version: str, root: Path = DEFAULT_ROOT):
    """Sets up each repo's Python environment using pyenv and pip.  Expects that you
    have pyenv configured locally already."""
    # Steps run inside every repo's environment, whether or not it is new
    install_steps = ("pip install -e '.[dev]'", "pre-commit install")

    for repo in repositories(root):
        # Only create and select a virtualenv (named after the repo) when the
        # repo has no .python-version file yet
        needs_virtualenv = not (repo / ".python-version").exists()
        if needs_virtualenv:
            _shell(repo, f"pyenv virtualenv {python_version} {repo.name}")
            _shell(repo, f"pyenv local {repo.name}")

        for step in install_steps:
            _in_env(repo, step)


def _git(repo: Path, *args: str, quiet: bool = False) -> str:
    return subprocess.check_output(
        ["git", *args],
        cwd=str(repo),
        text=True,
        stderr=subprocess.DEVNULL if quiet else None,
    )


@contextmanager
def _default_profile() -> Generator[None, None, None]:
    """Context manager that runs its body with the ``default`` Prefect profile.

    The active profile name is read from Prefect's settings context.  If it is
    not already ``default``, the profile is switched with
    ``prefect profile use default`` on entry and switched back to the original
    on exit, even when the body raises.
    """
    # Imported lazily so that prefect is only required when this is used
    import prefect.context

    previous = prefect.context.get_settings_context().profile.name
    must_switch = previous != "default"

    if must_switch:
        _shell(Path(), "prefect profile use default")
    try:
        yield
    finally:
        if must_switch:
            _shell(Path(), f"prefect profile use {previous}")


@app.command()
def clone(root: Path = DEFAULT_ROOT):
    """Clones all collections beneath the given root directory, and ensures they
    are all up-to-date and checked out to the `main` branch."""
    # Iterating `repositories` performs any missing clones; each checkout is
    # then switched to main and pulled, in that order
    for checkout in repositories(root):
        for git_args in (("checkout", "main"), ("pull",)):
            _git(checkout, *git_args)


@app.command()
def branch(branch: str, root: Path = DEFAULT_ROOT):
    """Creates a new branch on all collection repositories"""
    # `checkout -b` both creates the branch and switches to it
    create_and_switch = ("checkout", "-b", branch)
    for checkout in repositories(root):
        _git(checkout, *create_and_switch)


@app.command()
def reset(root: Path = DEFAULT_ROOT):
    """Resets the files on all repositories back to the HEAD of their branch"""
    # Runs `git checkout .` in each checkout
    discard_changes = ("checkout", ".")
    for checkout in repositories(root):
        _git(checkout, *discard_changes)


@app.command()
def status(root: Path = DEFAULT_ROOT):
    """Reports the status of the working tree in all repositories"""
    for repo in repositories(root):
        # _git captures git's standard output and returns it instead of letting
        # it reach the terminal, so the report has to be printed explicitly.
        # The repo path is printed first so each report can be attributed.
        print(repo)
        print(_git(repo, "status"))


@app.command()
def diff(root: Path = DEFAULT_ROOT):
    """Reports a diff of all repositories"""
    for repo in repositories(root):
        # _git captures git's standard output and returns it instead of letting
        # it reach the terminal, so the diff has to be printed explicitly
        changes = _git(repo, "--no-pager", "diff")
        if not changes:
            continue

        # Head each diff with the repo path so it can be attributed
        print(repo)
        print(changes, end="")


@app.command()
def commit(message: str, root: Path = DEFAULT_ROOT):
    """Commits all changes on all repositories"""
    for repo in repositories(root):
        # `git status --porcelain` lists untracked and already-staged files as
        # well as unstaged edits.  `git diff --name-only` only covers the
        # latter, which would skip repos whose only changes are new or staged
        # files even though `git add .` below would pick them up.
        pending = _git(repo, "status", "--porcelain")
        if not pending.strip():
            continue

        _git(repo, "add", ".")
        _git(repo, "commit", "-m", message)


@app.command()
def pull_request(root: Path = DEFAULT_ROOT, dry_run: bool = False):
    """Creates a pull request on all repositories"""
    # Flow: open $EDITOR on a template, take the first line as the PR title and
    # the rest as the body, then push the current branch of every repository
    # that differs from origin/main and open a PR against main with `gh`.
    # With dry_run, nothing is pushed or created; the commands are printed.
    # Exits with status 1 if EDITOR is unset or the edited text is empty, and
    # re-raises subprocess.CalledProcessError for any gh failure other than
    # "already exists".

    # Fail before creating the temporary file if no editor is configured
    if "EDITOR" not in os.environ:
        print("No EDITOR variable set")
        raise typer.Exit(1)

    with tempfile.NamedTemporaryFile("w+t") as pr_body_file:
        pr_body_path = Path(pr_body_file.name)
        pr_body_file.write("The title\n\nThe body\n")
        pr_body_file.flush()

        # EDITOR may carry its own arguments, hence the shell invocation
        subprocess.check_call(f"$EDITOR {pr_body_file.name}", shell=True)

        # Read and rewrite by path rather than through the open handle: an
        # editor that saves by writing a new file and renaming it over the
        # original would leave the handle pointing at the stale template
        text = pr_body_path.read_text().strip()
        if not text:
            print("No pull request body provided")
            raise typer.Exit(1)

        title, _, body = text.partition("\n")
        title = title.strip()
        body = body.strip()

        # Leave only the body in the file that gh reads via --body-file
        pr_body_path.write_text(body)

        DELAY = 10
        print(
            (
                "Note: there is an intentional delay between each repo to avoid"
                " exceeding Github's rate limit for creating pull requests"
            ),
            file=sys.stderr,
        )

        for repo in repositories(root):
            print(repo)

            changed_files = _git(repo, "diff", "origin/main", "--name-only")
            if not changed_files.strip():
                continue

            branch_name = _git(repo, "rev-parse", "--abbrev-ref", "HEAD").strip()

            arguments = [
                "gh",
                "pr",
                "create",
                "--repo",
                f"PrefectHQ/{repo.name}",
                "--base",
                "main",
                "--title",
                title,
                "--body-file",
                pr_body_file.name,
            ]
            if dry_run:
                # A dry run must have no side effects, so report the push
                # instead of performing it
                print(f"git push origin {branch_name}")
                print(" ".join(arguments))
                continue

            # Only pause when GitHub is about to be contacted
            time.sleep(DELAY)
            _git(repo, "push", "origin", branch_name)

            try:
                # Merge stderr into the captured output so that gh's error
                # text is available on the exception whichever stream it
                # was written to
                output = subprocess.check_output(
                    arguments,
                    cwd=str(repo),
                    text=True,
                    stderr=subprocess.STDOUT,
                )
            except subprocess.CalledProcessError as exc:
                details = exc.stdout or ""
                if "already exists" in details:
                    print(
                        f"{repo.name}: pull request already exists",
                        file=sys.stderr,
                    )
                    continue

                # Any other failure is unexpected: show it and stop rather
                # than silently moving on
                print(details, file=sys.stderr)
                raise

            # On success, show what gh reported
            print(output.strip())


@app.command()
def test(root: Path = DEFAULT_ROOT, forge_ahead: bool = True):
    """Runs tests on all repositories"""
    # Tests run with the default Prefect profile active, one repo at a time,
    # inside that repo's pyenv environment
    with _default_profile():
        for checkout in repositories(root):
            try:
                _in_env(checkout, "pytest tests")
            except subprocess.CalledProcessError:
                # With forge_ahead, a failing repo does not stop the others
                if forge_ahead:
                    continue
                raise


@enum.unique
class BumpLevel(enum.Enum):
    """The version component that a release bump increments."""

    # Each value is the lowercase form of the member name
    MAJOR = "major"
    MINOR = "minor"
    PATCH = "patch"


def _latest_release_tag(repo: Path) -> Optional[str]:
    """Return the tag name `gh release view` reports for ``repo``, or None
    when gh reports that no release was found."""
    try:
        raw = subprocess.check_output(
            ["gh", "release", "view", "--json", "tagName"],
            cwd=str(repo),
            text=True,
            # stderr must be captured for exc.stderr to be populated; without
            # this it is None and the "release not found" check cannot work
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as exc:
        details = exc.stderr or ""
        if "release not found" in details:
            return None

        # The captured stderr would otherwise be lost; show it before failing
        print(details, file=sys.stderr)
        raise

    tag = json.loads(raw)["tagName"]
    if not isinstance(tag, str):
        raise ValueError(f"Unexpected tagName for {repo.name}: {tag!r}")
    return tag


def _next_release_tag(previous_release: Optional[str], level: BumpLevel) -> str:
    """Return the ``vMAJOR.MINOR.PATCH`` tag that follows ``previous_release``
    when bumped at ``level``; a missing previous release counts as v0.0.0."""
    if not previous_release:
        major, minor, patch = 0, 0, 0
    else:
        if not previous_release.startswith("v"):
            raise ValueError(
                f"Release tag {previous_release!r} does not start with 'v'"
            )
        try:
            # Only the first three dot-separated parts are considered
            major, minor, patch = (
                int(part) for part in previous_release[1:].split(".")[:3]
            )
        except ValueError as exc:
            raise ValueError(
                f"Cannot parse release tag {previous_release!r} as "
                "vMAJOR.MINOR.PATCH"
            ) from exc

    if level == BumpLevel.MAJOR:
        major, minor, patch = major + 1, 0, 0
    elif level == BumpLevel.MINOR:
        minor, patch = minor + 1, 0
    elif level == BumpLevel.PATCH:
        patch += 1

    return f"v{major}.{minor}.{patch}"


@app.command()
def bump_versions(level: BumpLevel, root: Path = DEFAULT_ROOT, dry_run: bool = False):
    """Bumps the version of all collections to the latest version of Prefect"""
    # For every repository: update main, look up the latest GitHub release,
    # skip the repo if origin/main has no commits beyond that release, and
    # otherwise create a new release (with generated notes) whose tag is the
    # previous one bumped at `level`.  With dry_run the `gh release create`
    # command is printed instead of executed.
    DELAY = 10
    print(
        (
            "Note: there is an intentional delay between each repo to avoid exceeding "
            "Github's rate limit for creating releases"
        ),
        file=sys.stderr,
    )

    for repo in repositories(root):
        time.sleep(DELAY)

        _git(repo, "checkout", "main", quiet=True)
        _git(repo, "pull")

        previous_release = _latest_release_tag(repo)

        if previous_release:
            since_previous_release = _git(
                repo, "log", f"{previous_release}...origin/main", "--pretty=oneline"
            )
            # Strip so that whitespace-only output also counts as "no commits"
            if not since_previous_release.strip():
                print(f"{repo.name} has no commits since {previous_release}")
                continue

        next_release = _next_release_tag(previous_release, level)

        print(f"{repo.name} {previous_release} -> {next_release}")

        arguments = [
            "gh",
            "release",
            "create",
            next_release,
            "--generate-notes",
        ]
        if dry_run:
            print(" ".join(arguments))
            continue

        subprocess.check_call(arguments, cwd=str(repo), text=True)


@app.command()
def run_command(command: str, root: Path = DEFAULT_ROOT):
    """Runs the given command on all repositories.  Make sure to quote the command
    if it has more than one word!  The environment variable `COLLECTION` will be set
    to the name of the current collection, and the command will be run in the
    collection repository's directory."""
    # The command string is handed to the shell unchanged, once per checkout;
    # a failing command raises and stops the iteration
    for checkout in repositories(root):
        _shell(checkout, command)


@app.command()
def run_function(function_path: str, root: Path = DEFAULT_ROOT):
    """Runs the given Python function (specified as a module path) on all repositories.
    The Python function must be importable from the current directory (not the
    repository directory).  For example, with a module like this:

        # a_path/my_lister.py
        import glob
        from pathlib import Path

        def main(repo: Path):
            for file in glob.glob("*.py"):
                print(file)

    You can run it like this:

        collections-manager run-function a_path.my_lister:main

    """
    # Split "package.module:function" at the first colon; both halves are required
    module_name, _colon, function_name = function_path.partition(":")
    if not (module_name and function_name):
        raise typer.BadParameter(
            "function_name must be specified as a module path followed by a colon and "
            "a function name, like my_module.my_nested_module:my_function"
        )

    # Putting the current directory first on sys.path lets a script dropped
    # there be imported as a module
    sys.path.insert(0, ".")

    # Resolve the callable once, then call it with each checkout path
    function = getattr(import_module(module_name), function_name)
    for checkout in repositories(root):
        function(checkout)


# Invoke the Typer application only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    app()
